{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 304 - Medical Entity Extraction with a BiLSTM\n",
    "\n",
    "In this tutorial we use a Bidirectional LSTM entity extractor from the MMLSPark\n",
    "model downloader to extract entities from PubMed medical abstracts\n",
    "\n",
    "Our goal is to identify useful entities in a block of free-form text.  This is a\n",
    "nontrivial task because entities might be referenced in the text using variety of\n",
    "synonymns, abbreviations, or formats. Our target output for this model is a set\n",
    "of tags that specify what kind of entity is referenced. The model we use was\n",
    "trained on a large dataset of publically tagged pubmed abstracts. An example\n",
    "annotated sequence is given below, \"O\" represents no tag:\n",
    "\n",
    "|I-Chemical | O   |I-Chemical  | O   | O   |I-Chemical | O   |I-Chemical  |",
    " O   | O      | O   | O   |I-Disease |I-Disease| O   | O    |\n",
    "|:---:      |:---:|:---:       |:---:|:---:|:---:      |:---:|:---:       |",
    ":---:|:---:   |:---:|:---:|:---:     |:---:    |:---:|:---: |\n",
    "|Baricitinib| ,   |Methotrexate| ,   | or  |Baricitinib|Plus |Methotrexate|",
    " in  |Patients|with |Early|Rheumatoid|Arthritis| Who |Had...|\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mmlspark.cntk import CNTKModel\n",
    "from mmlspark.downloader import ModelDownloader\n",
    "from pyspark.sql.functions import udf, col\n",
    "from pyspark.sql.types import IntegerType, ArrayType, FloatType, StringType\n",
    "from pyspark.sql import Row\n",
    "\n",
    "from os.path import abspath, join\n",
    "import numpy as np\n",
    "from nltk.tokenize import sent_tokenize, word_tokenize\n",
    "import os, tarfile, pickle\n",
    "import urllib.request\n",
    "import nltk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the model and extract the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "hdinsight",
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "modelName = \"BiLSTM\"\n",
    "modelDir = \"models\"\n",
    "if not os.path.exists(modelDir): os.makedirs(modelDir)\n",
    "d = ModelDownloader(spark, \"dbfs:///\" + modelDir)\n",
    "modelSchema = d.downloadByName(modelName)\n",
    "nltk.download(\"punkt\", \"/dbfs/nltkdata\")\n",
    "nltk.data.path.append(\"/dbfs/nltkdata\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "local",
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "modelName = \"BiLSTM\"\n",
    "modelDir = abspath(\"models\")\n",
    "if not os.path.exists(modelDir): os.makedirs(modelDir)\n",
    "d = ModelDownloader(spark, \"file://\" + modelDir)\n",
    "modelSchema = d.downloadByName(modelName)\n",
    "nltk.download(\"punkt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Download the embeddings\n",
    "\n",
    "We use the nltk punkt sentence and word tokenizers and a set of embeddings trained on PubMed Articles"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wordEmbFileName = \"WordEmbeddings_PubMed.pkl\"\n",
    "pickleFile = join(abspath(\"models\"), wordEmbFileName)\n",
    "if not os.path.isfile(pickleFile):\n",
    "    urllib.request.urlretrieve(\"https://mmlspark.blob.core.windows.net/datasets/\" + wordEmbFileName, pickleFile)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the embeddings and create functions for encoding sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickleContent = pickle.load(open(pickleFile, \"rb\"), encoding=\"latin-1\")\n",
    "wordToIndex = pickleContent[\"word_to_index\"]\n",
    "wordvectors = pickleContent[\"wordvectors\"]\n",
    "classToEntity = pickleContent[\"class_to_entity\"]\n",
    "\n",
    "nClasses = len(classToEntity)\n",
    "nFeatures = wordvectors.shape[1]\n",
    "maxSentenceLen = 613"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "content = \"Baricitinib, Methotrexate, or Baricitinib Plus Methotrexate in Patients with Early Rheumatoid\\\n",
    " Arthritis Who Had Received Limited or No Treatment with Disease-Modifying-Anti-Rheumatic-Drugs (DMARDs):\\\n",
    " Phase 3 Trial Results. Keywords: Janus kinase (JAK), methotrexate (MTX) and rheumatoid arthritis (RA) and\\\n",
    " Clinical research. In 2 completed phase 3 studies, baricitinib (bari) improved disease activity with a\\\n",
    " satisfactory safety profile in patients (pts) with moderately-to-severely active RA who were inadequate\\\n",
    " responders to either conventional synthetic1 or biologic2DMARDs. This abstract reports results from a\\\n",
    " phase 3 study of bari administered as monotherapy or in combination with methotrexate (MTX) to pts with\\\n",
    " early active RA who had limited or no prior treatment with DMARDs. MTX monotherapy was the active comparator.\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = sent_tokenize(content)\n",
    "df = spark.createDataFrame(enumerate(sentences), [\"index\",\"sentence\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "mml-deploy": "hdinsight",
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Add the tokenizers to all worker nodes\n",
    "def prepNLTK(partition):\n",
    "    nltk.data.path.append(\"/dbfs/nltkdata\")\n",
    "    return partition\n",
    "\n",
    "df = df.rdd.mapPartitions(prepNLTK).toDF()\n",
    "df.count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizeUDF = udf(word_tokenize, ArrayType(StringType()))\n",
    "df = df.withColumn(\"tokens\",tokenizeUDF(\"sentence\"))\n",
    "\n",
    "countUDF = udf(len, IntegerType())\n",
    "df = df.withColumn(\"count\",countUDF(\"tokens\"))\n",
    "\n",
    "def wordToEmb(word):\n",
    "    return wordvectors[wordToIndex.get(word.lower(), wordToIndex[\"UNK\"])]\n",
    "\n",
    "def featurize(tokens):\n",
    "    X = np.zeros((maxSentenceLen, nFeatures))\n",
    "    X[-len(tokens):,:] = np.array([wordToEmb(word) for word in tokens])\n",
    "    return [float(x) for x in X.reshape(maxSentenceLen, nFeatures).flatten()]\n",
    "\n",
    "featurizeUDF = udf(featurize,  ArrayType(FloatType()))\n",
    "df = df.withColumn(\"features\", featurizeUDF(\"tokens\"))\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run the CNTKModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = CNTKModel() \\\n",
    "    .setModelLocation(modelSchema.uri) \\\n",
    "    .setInputCol(\"features\") \\\n",
    "    .setOutputCol(\"probs\") \\\n",
    "    .setOutputNodeIndex(0) \\\n",
    "    .setMiniBatchSize(1)\n",
    "\n",
    "df = model.transform(df).cache()\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def probsToEntities(probs, wordCount):\n",
    "    reshaped_probs = np.array(probs).reshape(maxSentenceLen, nClasses)\n",
    "    reshaped_probs = reshaped_probs[-wordCount:,:]\n",
    "    return [classToEntity[np.argmax(probs)] for probs in reshaped_probs]\n",
    "\n",
    "toEntityUDF = udf(probsToEntities,ArrayType(StringType()))\n",
    "df = df.withColumn(\"entities\", toEntityUDF(\"probs\", \"count\"))\n",
    "df.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Show the annotated text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Color Code the Text based on the entity type\n",
    "colors = {\n",
    "    \"B-Disease\": \"blue\",\n",
    "    \"I-Disease\":\"blue\",\n",
    "    \"B-Drug\":\"lime\",\n",
    "    \"I-Drug\":\"lime\",\n",
    "    \"B-Chemical\":\"lime\",\n",
    "    \"I-Chemical\":\"lime\",\n",
    "    \"O\":\"black\",\n",
    "    \"NONE\":\"black\"\n",
    "}\n",
    "\n",
    "def prettyPrint(words, annotations):\n",
    "    formattedWords = []\n",
    "    for word,annotation in zip(words,annotations):\n",
    "        formattedWord = \"<font size = '2' color = '{}'>{}</font>\".format(colors[annotation], word)\n",
    "        if annotation in {\"O\",\"NONE\"}:\n",
    "            formattedWords.append(formattedWord)\n",
    "        else:\n",
    "            formattedWords.append(\"<b>{}</b>\".format(formattedWord))\n",
    "    return \" \".join(formattedWords)\n",
    "\n",
    "prettyPrintUDF = udf(prettyPrint, StringType())\n",
    "df = df.withColumn(\"formattedSentence\", prettyPrintUDF(\"tokens\", \"entities\")) \\\n",
    "       .select(\"formattedSentence\")\n",
    "\n",
    "sentences = [row[\"formattedSentence\"] for row in df.collect()]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "for sentence in sentences:\n",
    "    display(HTML(sentence))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Example text used in this demo has been taken from:\n",
    "\n",
    "Fleischmann R, Takeuchi T, Schlichting DE, Macias WL, Rooney T, Gurbuz S, Stoykov I,\n",
    "Beattie SD, Kuo WL, Schiff M. Baricitinib, Methotrexate, or Baricitinib Plus Methotrexate\n",
    "in Patients with Early Rheumatoid Arthritis Who Had Received Limited or No Treatment with\n",
    "Disease-Modifying Anti-Rheumatic Drugs (DMARDs): Phase 3 Trial Results [abstract].\n",
    "Arthritis Rheumatol. 2015; 67 (suppl 10).\n",
    "http://acrabstracts.org/abstract/baricitinib-methotrexate-or-baricitinib-plus-methotrexate",
    "-in-patients-with-early-rheumatoid-arthritis-who-had-received-limited-or-no-treatment-with",
    "-disease-modifying-anti-rheumatic-drugs-dmards-p/.\n",
    "Accessed August 18, 2017."
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
